import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math

# Names exported by ``from <this module> import *``.
# NOTE(review): InitLinear, InitNormLinear and Classifier are defined in this
# module but not listed here, so a star-import will not expose them — confirm
# whether that is intentional.
__all__ = ['NormLinear', 'FixLinear', 'FixNormLinear']

class NormLinear(nn.Linear):
    """Linear layer on L2-normalized inputs and L2-normalized weight rows.

    Each input feature vector and each weight row is L2-normalized before the
    affine map, so without a bias the outputs are cosine similarities between
    the input and every weight row.

    Args:
        in_features: size of each input feature vector (last input dimension).
        out_features: number of weight rows / output features.
        bias: if True, the (trainable) bias of nn.Linear is added after the
            cosine product.
    """

    def __init__(self, in_features: int, out_features: int,
                 bias: bool = False):
        super(NormLinear, self).__init__(in_features, out_features, bias)

    def forward(self, input):
        # F.linear contracts over the LAST dimension, so normalize over that
        # same dimension. For 2-D (batch, in_features) input this is identical
        # to dim=1. It also keeps 1-D and N-D (e.g. (B, T, C)) input correct
        # instead of normalizing across the wrong axis.
        input = F.normalize(input, dim=-1)
        weight = F.normalize(self.weight, dim=-1)
        return F.linear(input, weight, self.bias)

class FixLinear(nn.Linear):
    """nn.Linear whose weight is supplied by the caller and kept frozen.

    Args:
        in_features: forwarded to nn.Linear.
        out_features: forwarded to nn.Linear.
        weight: weight values, converted with ``torch.Tensor``; presumably of
            shape (out_features, in_features) — not validated here.
        bias: forwarded to nn.Linear; only the weight is frozen, so a bias
            (if requested) remains trainable.
    """

    def __init__(self, in_features: int, out_features: int, weight, bias: bool = False):
        super().__init__(in_features, out_features, bias)
        # Replace the randomly initialized weight with the given one and
        # exclude it from gradient computation.
        frozen = torch.Tensor(weight)
        self.weight = Parameter(frozen, requires_grad=False)

class FixNormLinear(NormLinear):
    """NormLinear whose weight is supplied by the caller and kept frozen.

    Args:
        in_features: forwarded to NormLinear.
        out_features: forwarded to NormLinear.
        weight: weight values, converted with ``torch.Tensor``; presumably of
            shape (out_features, in_features) — not validated here.
        bias: forwarded to NormLinear; only the weight is frozen, so a bias
            (if requested) remains trainable.
    """

    def __init__(self, in_features: int, out_features: int, weight, bias: bool = False):
        super().__init__(in_features, out_features, bias)
        # Swap in the given weight and exclude it from gradient computation.
        frozen = torch.Tensor(weight)
        self.weight = Parameter(frozen, requires_grad=False)


class InitLinear(nn.Linear):
    """Trainable nn.Linear initialized from caller-supplied weight values.

    Args:
        in_features: forwarded to nn.Linear.
        out_features: forwarded to nn.Linear.
        weight: initial weight values, converted with ``torch.Tensor``;
            presumably of shape (out_features, in_features) — not validated.
        bias: forwarded to nn.Linear.
        scale: multiplier applied to ``weight`` before it is registered.
    """

    def __init__(self, in_features: int, out_features: int, weight, bias: bool = False, scale=1.0):
        super().__init__(in_features, out_features, bias)
        # The multiplication yields a fresh tensor, so later training updates
        # never write back into the caller's ``weight``.
        scaled = torch.Tensor(weight).mul(scale)
        self.weight = Parameter(scaled)


class InitNormLinear(NormLinear):
    """Trainable NormLinear initialized from caller-supplied weight values.

    Args:
        in_features: forwarded to NormLinear.
        out_features: forwarded to NormLinear.
        weight: initial weight values, converted with ``torch.Tensor``;
            presumably of shape (out_features, in_features) — not validated.
        bias: forwarded to NormLinear.
    """

    def __init__(self, in_features: int, out_features: int, weight, bias: bool = False):
        super().__init__(in_features, out_features, bias)
        init_weight = torch.Tensor(weight)
        self.weight = Parameter(init_weight, requires_grad=True)




class Classifier(nn.Module):
    """Classification head producing logits from (optionally pooled) features.

    Args:
        in_features: feature dimension of the input.
        num_classes: number of output logits.
        mode: weight parameterization, one of
            'linear'     - trainable nn.Linear with Kaiming-normal init
                           (``weight`` must be None);
            'norm'       - trainable NormLinear (``weight`` must be None);
            'fixlinear'  - FixLinear, frozen weight taken from ``weight``;
            'fixnorm'    - FixNormLinear, frozen weight taken from ``weight``;
            'initlinear' - InitLinear, trainable, initialized from
                           ``weight * scale``;
            'initnorm'   - InitNormLinear, trainable, initialized from
                           ``weight``.
            Modes ending in 'norm' make forward() return cosine logits.
        weight: weight values for the fix*/init* modes.
        bias: whether the underlying layer has a bias; it is applied in
            forward() in every mode.
        need_avg: if True, forward() global-average-pools its input (assumed
            to be an (N, C, H, W) feature map) and flattens it to (N, C).
        scale: multiplier for ``weight``; used only by 'initlinear'.

    Raises:
        NotImplementedError: for an unknown mode or a mode/weight mismatch
            (this includes the default ``mode='none'``).
    """

    def __init__(self, in_features, num_classes, mode='none', weight=None, bias=False, need_avg=False, scale=1.0):
        super(Classifier, self).__init__()
        self.need_avg = need_avg
        if self.need_avg:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # forward() only distinguishes cosine logits ('norm') from plain ones
        # ('linear'), so every mode ending in 'norm' collapses to 'norm'.
        self.mode = 'norm' if mode.endswith('norm') else 'linear'
        if mode == 'linear' and weight is None:
            self.linear = nn.Linear(in_features, num_classes, bias=bias)
            nn.init.kaiming_normal_(self.linear.weight)
        elif mode == 'norm' and weight is None:
            self.linear = NormLinear(in_features, num_classes, bias=bias)
        elif mode == 'fixnorm' and weight is not None:
            self.linear = FixNormLinear(in_features, num_classes, weight, bias=bias)
        elif mode == 'fixlinear' and weight is not None:
            self.linear = FixLinear(in_features, num_classes, weight, bias=bias)
        elif mode == 'initlinear' and weight is not None:
            self.linear = InitLinear(in_features, num_classes, weight, bias=bias, scale=scale)
        elif mode == 'initnorm' and weight is not None:
            self.linear = InitNormLinear(in_features, num_classes, weight, bias=bias)
        else:
            weight_state = 'None' if weight is None else 'given'
            raise NotImplementedError(
                f"unsupported Classifier mode {mode!r} (weight is {weight_state}): "
                "'linear' and 'norm' require weight=None; 'fixlinear', 'fixnorm', "
                "'initlinear' and 'initnorm' require a weight")

    def change(self, mode='none'):
        """Switch the logit computation used by forward().

        forward() produces cosine logits only when ``mode == 'norm'``; any
        other value selects plain linear logits. The underlying layer is not
        rebuilt.
        """
        self.mode = mode

    def forward(self, x, labels=None, adjusted=False, eps=0.75):
        """Compute class logits.

        Args:
            x: features, expected as (N, in_features) — or a feature map that
                is pooled and flattened first when ``need_avg`` is set.
            labels: unused; kept for call-site compatibility.
            adjusted: in 'norm' mode, multiply each row of cosine logits by the
                detached L2 norm of the corresponding input feature.
            eps: unused; kept for call-site compatibility.

        Returns:
            Logits of shape (N, num_classes).
        """
        if self.need_avg:
            x = torch.flatten(self.avgpool(x), 1)
        if self.mode != 'norm':
            return F.linear(x, self.linear.weight, self.linear.bias)

        # Same computation as NormLinear.forward for 2-D input, including the
        # bias, which is added to the cosine similarities.
        logits = F.linear(F.normalize(x, dim=1),
                          F.normalize(self.linear.weight, dim=1),
                          self.linear.bias)
        if adjusted:
            # The (N, 1) norm broadcasts across classes. It is detached so no
            # gradient flows through the rescaling factor.
            logits = x.norm(dim=1, keepdim=True).detach() * logits
        return logits

    def margin(self):
        """Summarize the geometry of the classifier weight rows.

        Returns:
            Tuple ``(min_angle_deg, min_norm, ratio)``:
              * min_angle_deg (float): smallest angle in degrees between any
                two distinct weight rows (about 180 with a single row);
              * min_norm (0-dim tensor): smallest row L2 norm;
              * ratio (float): largest row norm divided by the smallest.
        """
        assert hasattr(self, 'linear')
        weight = self.linear.weight
        norm = torch.sqrt(torch.sum(weight ** 2, dim=1))
        min_norm = torch.min(norm)
        ratio = torch.max(norm) / min_norm

        unit = F.normalize(weight, dim=1)
        # Subtracting 2 on the diagonal keeps each row's self-similarity (~1)
        # from ever being the maximum.
        similarity = unit @ unit.t() - 2 * torch.eye(unit.size(0), device=weight.device)
        # Rounding can push the cosine of (near-)parallel rows slightly beyond
        # +/-1, where acos returns NaN; clamp into acos's domain.
        max_cos = torch.max(similarity).clamp(-1.0, 1.0)
        return torch.acos(max_cos).item() / math.pi * 180, min_norm, ratio.item()